"""
Various utils for data preprocessing

"""
# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
# pylint: disable=no-else-return
import os
import pickle
from time import time
import gc
import random
from copy import copy, deepcopy
import hashlib

import numpy as np
import scipy.sparse as sps
import tables
import torch
import dgl
import dgl.data
import dgl.function as fn
import dgl.sparse as dglsp
from dgl.data import AsNodePredDataset
from ogb.nodeproppred import DglNodePropPredDataset
from dgl.dataloading import MultiLayerFullNeighborSampler


from model import SAGE
from utils import load_train_conf

## Basic utils
# def hash_list(foo):
#     assert isinstance(foo, list)
#     my_foo = copy(foo)
#     my_foo.sort()
#     return hash(tuple(my_foo))

def hash_list(foo):
    """Return an order-independent SHA-256 hex digest of a list.

    A sorted copy of ``foo`` is taken (the caller's list is left untouched),
    its ``str()`` representation is encoded and hashed. Two lists holding the
    same elements in a different order therefore produce the same digest.
    """
    assert isinstance(foo, list)
    ordered = sorted(foo)
    return hashlib.sha256(str(ordered).encode()).hexdigest()

def load_data(name, seed=234, check_split=True):
    """Load a graph dataset by name and attach split / class metadata.

    The DGL built-in datasets (cora, citeseer, pubmed, reddit, a-computer,
    a-photo) receive a seeded random split through ``split_dataset``. The OGB
    datasets keep the split shipped with the dataset and are created with
    ``<directory of this file>/dataset`` as their root. In every case self
    loops are removed and re-added, and ``train_idx`` / ``val_idx`` /
    ``test_idx`` / ``num_classes`` are stored as attributes on the graph.

    Parameters
    ----------
    name : str
        Dataset identifier, see above.
    seed : int
        Seed forwarded to ``split_dataset`` (unused for the OGB datasets).
    check_split : bool
        When true, print hashes of the three index sets and of the first
        1000 labels so that runs can be compared.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported identifiers.
    """
    data_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "dataset")

    # dataset name -> class name inside dgl.data, all randomly split
    random_split_sets = {
        "cora": "CoraGraphDataset",
        "citeseer": "CiteseerGraphDataset",
        "pubmed": "PubmedGraphDataset",
        # !!! need to use the original split
        "reddit": "RedditDataset",
        "a-computer": "AmazonCoBuyComputerDataset",
        "a-photo": "AmazonCoBuyPhotoDataset",
    }

    if name in random_split_sets:
        dataset = getattr(dgl.data, random_split_sets[name])()
        g = dataset[0]
        split_dataset(g, seed=seed)
    elif name in ("ogbn-products", "ogbn-arxiv"):
        dataset = AsNodePredDataset(DglNodePropPredDataset(name, root=data_path))
        g = dataset[0]
        if name == "ogbn-arxiv":
            # add the reverse of every edge, then normalise the self loops
            srcs, dsts = g.all_edges()
            dataset[0].add_edges(dsts, srcs)
            g = g.remove_self_loop().add_self_loop()
        # for ogbn-products: !!! TODO only add the non-existing edges
        g.train_idx, g.val_idx, g.test_idx = dataset.train_idx, dataset.val_idx, dataset.test_idx
    else:
        raise ValueError(f"Invalid data name: {name}")

    g = g.remove_self_loop().add_self_loop()
    g.tag_remove_and_addedd_loop = True
    g.num_classes = dataset.num_classes
    print(f"Refined train size: {len(g.train_idx)}"
          f", val size: {len(g.val_idx)}"
          f", test size: {len(g.test_idx)}.")
    if check_split:
        train_hash, val_hash, test_hash = (
            hash_list(split.tolist())
            for split in (g.train_idx, g.val_idx, g.test_idx)
        )
        print(f"Split check: train hash {train_hash}, val hash {val_hash}, test hash {test_hash}")
        first_labels = g.ndata["label"][:1000].numpy().tolist()
        label_hash = hash_list(first_labels)
        print(f"First 1000 label hash is {label_hash}")
    return g

def split_dataset(g, seed=666):
    """
    Attach a seeded random train/val/test node split to ``g``.

    ``int(0.2 * num_nodes)`` nodes are drawn for validation, the same number
    for test from what is left, and every remaining node is used for
    training. The results are stored as ``g.train_idx``, ``g.val_idx`` and
    ``g.test_idx`` tensors. Note that this reseeds numpy's global RNG.

    https://github.com/hwwang55/GCN-LPA/blob/21df23afee0912380ac682b0f80ac244140c33e7/src/data_loader.py#L40
    """
    np.random.seed(seed)
    node_count = g.num_nodes()
    holdout_size = int(node_count * 0.2)

    val_ids = np.random.choice(list(range(node_count)), size=holdout_size, replace=False)
    remaining = set(range(node_count)) - set(val_ids)
    test_ids = np.random.choice(list(remaining), size=holdout_size, replace=False)
    train_ids = list(remaining - set(test_ids))

    g.train_idx = torch.tensor(train_ids)
    g.val_idx = torch.tensor(val_ids)
    g.test_idx = torch.tensor(test_ids)

def get_index(data):
    """Derive index tensors from the boolean split masks stored on ``data``.

    For each of ``train``, ``val`` and ``test`` the positions where
    ``data.<split>_mask`` is truthy are collected and written back as
    ``data.<split>_idx`` (a torch tensor).
    """
    for split in ("train", "val", "test"):
        mask = getattr(data, f"{split}_mask")
        selected = [pos for pos, flag in enumerate(mask) if flag]
        setattr(data, f"{split}_idx", torch.tensor(selected))

def create_partial_identity(idx, total_len):
    """Build a sparse ``total_len x len(idx)`` selector matrix.

    Column ``j`` holds a single 1 in row ``idx[j]``, i.e. the matrix consists
    of the columns of the ``total_len`` identity picked out by ``idx``.
    Returned as a scipy COO matrix.
    """
    n_cols = len(idx)
    return sps.coo_matrix(
        (np.ones(n_cols), (idx, np.arange(n_cols))),
        shape=(total_len, n_cols),
    )

def torch_sparse_to_scipy_coo(x):
    """
    Convert a 2-D torch sparse COO tensor into a scipy COO matrix.

    The tensor is coalesced first, so duplicate coordinates are summed. The
    tensor must live on the CPU because ``.numpy()`` is called on its values
    and indices.
    """
    # Coalesce once and reuse the result; the operation sorts and merges all
    # entries, which is expensive on large tensors.
    coalesced = x.coalesce()
    values = coalesced.values().detach().numpy()
    indices = coalesced.indices().detach().numpy()
    return sps.coo_matrix((values, (indices[0], indices[1])), shape=tuple(x.shape))

def scipy_coo_to_torch_sparse(x):
    """
    Convert a scipy COO sparse matrix into a float32 torch sparse COO tensor.

    Uses ``torch.sparse_coo_tensor`` rather than the deprecated legacy
    ``torch.sparse.FloatTensor`` constructor. The index and value data are
    copied, so the result does not alias the scipy matrix.
    """
    indices = torch.tensor(np.vstack((x.row, x.col)), dtype=torch.long)
    values = torch.tensor(x.data, dtype=torch.float32)
    return torch.sparse_coo_tensor(indices, values, torch.Size(x.shape))

def dgl_extract_sparse_tensor_graph(g, edge_weight_name):
    """Return the weighted adjacency of ``g`` as a torch sparse COO tensor.

    Entry ``(src, dst)`` holds ``g.edata[edge_weight_name]`` of that edge and
    the shape is taken from ``g.adj()``. The result is not coalesced.

    ``torch.sparse_coo_tensor`` replaces the deprecated legacy
    ``torch.sparse.FloatTensor`` constructor; it takes dtype and device from
    the edge values instead of assuming a fixed tensor type.
    """
    values = g.edata[edge_weight_name].detach()
    indices = torch.vstack(g.edges())
    shape = tuple(g.adj().shape)

    return torch.sparse_coo_tensor(indices, values, shape)

def norm_randomwalk_adj(adj):
    """Random walk normalize adjacency matrix, i.e. return ``D^-1 A``.

    ``D`` is the diagonal matrix of row sums of ``adj``. Rows that sum to
    zero are left as all-zero rows. The result is a scipy COO matrix.
    """
    adj = sps.coo_matrix(adj)
    rowsum = np.asarray(adj.sum(1)).flatten()
    # Integer (or boolean) adjacency: numpy refuses negative integer powers
    # of integer arrays, so move to floating point first.
    if not np.issubdtype(rowsum.dtype, np.inexact):
        rowsum = rowsum.astype(np.float64)
    # Zero rows produce inf here on purpose; they are zeroed right below, so
    # the divide-by-zero warning is only noise.
    with np.errstate(divide="ignore"):
        d_inv = np.power(rowsum, -1)
    d_inv[np.isinf(d_inv)] = 0.
    d_inv_mat = sps.diags(d_inv)
    return d_inv_mat.dot(adj).tocoo()

def norm_symetric_adj(adj):
    """Symmetrically normalize adjacency matrix.

    With ``D`` the diagonal matrix of row sums, this evaluates
    ``(A D^-1/2)^T D^-1/2`` and returns it as a scipy COO matrix. Infinite
    scaling factors (rows summing to zero) are replaced by zero.
    """
    coo = sps.coo_matrix(adj)
    degree = np.array(coo.sum(1))
    inv_sqrt_degree = np.power(degree, -0.5).flatten()
    inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.
    scale = sps.diags(inv_sqrt_degree)
    column_scaled = coo.dot(scale)
    return column_scaled.transpose().dot(scale).tocoo()

def row_norm(x):
    """Row normalization: divide every row of the 2-D array ``x`` by its sum.

    Rows that sum to zero are returned as all-zero rows.
    """
    row_sum = np.sum(x, axis=1)
    # numpy refuses negative integer powers of integer arrays, so integer
    # input is moved to floating point before inverting.
    if not np.issubdtype(row_sum.dtype, np.inexact):
        row_sum = row_sum.astype(np.float64)
    # inf from zero rows is expected and zeroed below; silence the warning.
    with np.errstate(divide="ignore"):
        row_sum_inv = np.power(row_sum, -1).flatten()
    row_sum_inv[np.isinf(row_sum_inv)] = 0
    return x * row_sum_inv.reshape([-1, 1])

def sigmoid_np_with_torch(target, t, device="cpu"):
    """Evaluate ``sigmoid(target * t)`` with torch and return a numpy array.

    ``target`` is a numpy array; it is moved to ``device`` for the
    computation and the result is brought back to the CPU.
    """
    scaled = torch.from_numpy(target).to(device) * t
    return torch.sigmoid(scaled).cpu().detach().numpy()

def pickle_dump(path_out, X, var_name=""):
    """Pickle ``X`` into ``path_out`` (protocol 4) and print a short notice.

    ``var_name`` is only used in the printed message.
    """
    with open(path_out, 'wb') as sink:
        pickle.Pickler(sink, protocol=4).dump(X)
    print(f"Cached {var_name} in {path_out}.")

def pickle_load(path_in, var_name=""):
    """Unpickle and return the object stored in ``path_in``.

    ``var_name`` is only used in the printed message.
    NOTE: unpickling executes arbitrary code; only use on trusted cache files.
    """
    with open(path_in, 'rb') as source:
        loaded = pickle.Unpickler(source).load()
    print(f"Loaded {var_name} from {path_in}.")
    return loaded

def h5_store_csr(M, name, path):
    """
    Store a csr matrix in HDF5

    The four arrays ``data``, ``indices``, ``indptr`` and ``shape`` are
    written as separate nodes named ``<name>_<attribute>`` below the file
    root. The file is opened in append mode; nodes of the same name that
    already exist are removed and rewritten.

    Parameters
    ----------
    M : scipy.sparse.csr_matrix
        sparse matrix to be stored

    name: str
        node prefix in HDF5 hierarchy

    path: str
        HDF5 file path

    Source: https://stackoverflow.com/questions/11129429/storing-numpy-sparse-matrix-in-hdf5-pytables
    """
    # Use the public class (the private ``scipy.sparse.csr`` namespace is
    # deprecated) and isinstance so that subclasses are accepted as well.
    assert isinstance(M, sps.csr_matrix), 'M must be a csr matrix'
    with tables.open_file(path, 'a') as f:
        for attribute in ('data', 'indices', 'indptr', 'shape'):
            full_name = f'{name}_{attribute}'

            # remove existing nodes
            try:
                n = getattr(f.root, full_name)
                n._f_remove()
            except AttributeError:
                pass

            # add nodes
            arr = np.array(getattr(M, attribute))
            atom = tables.Atom.from_dtype(arr.dtype)
            ds = f.create_carray(f.root, full_name, atom, arr.shape)
            ds[:] = arr
    print(f"Dumped {name} into {path}")

def h5_load_csr(name, path):
    """
    Load a csr matrix from HDF5

    Reads the nodes ``<name>_data``, ``<name>_indices``, ``<name>_indptr``
    and ``<name>_shape`` from the file root and rebuilds the matrix.

    Parameters
    ----------
    name: str
        node prefix in HDF5 hierarchy

    path: str
        HDF5 file path

    Returns
    ----------
    M : scipy.sparse.csr_matrix
        loaded sparse matrix

    Source: https://stackoverflow.com/questions/11129429/storing-numpy-sparse-matrix-in-hdf5-pytables
    """
    fields = ('data', 'indices', 'indptr', 'shape')
    with tables.open_file(path) as f:
        # every node is read fully into memory before the file is closed
        data, indices, indptr, shape = (
            getattr(f.root, f'{name}_{field}').read() for field in fields
        )

    M = sps.csr_matrix((data, indices, indptr), shape=shape)
    print(f"Loaded {name} from {path}")
    return M

def h5_store_np_array(M, name, path):
    """
    Store a numpy array as node ``name`` below the root of an HDF5 file.

    The file is opened in append mode. Unlike ``h5_store_csr`` no attempt is
    made here to remove an already existing node of the same name.
    """
    with tables.open_file(path, 'a') as f:
        node = f.create_carray(f.root, name, tables.Atom.from_dtype(M.dtype), M.shape)
        node[:] = M
    print(f"Dumped {name} into {path}")

def h5_load_np_array(name, path):
    """
    Read the node ``name`` below the root of an HDF5 file and return its
    content.
    """
    with tables.open_file(path) as f:
        node = getattr(f.root, f'{name}')
        M = node.read()
    print(f"loaded {name} from {path}")
    return M

class my_lil_matrix(sps.lil_matrix):
    """
    Initialize a lil matrix with given data and rows.
    https://stackoverflow.com/questions/24192876/initialize-lil-matrix-given-data-and-coordinates

    Usage: ``my_lil_matrix(data, rows, shape=None, dtype=None)`` where
    ``rows[i]`` lists the column indices of row ``i`` and ``data[i]`` the
    matching values. ``rows`` must be a ``list``; any other call signature is
    forwarded unchanged to ``scipy.sparse.lil_matrix``.

    When ``shape`` is omitted the matrix has ``len(rows)`` rows and
    ``largest column index + 1`` columns. Column indices within a row may be
    given in any order but must not repeat.

    Raises
    ------
    TypeError
        If more than ``data`` and ``rows`` are passed positionally.
    ValueError
        If ``data`` and ``rows`` disagree in length, or do not fit ``shape``.
    """
    def __init__(self, *args, **kwargs):
        if len(args) >= 2 and isinstance(args[1], list):
            if len(args) != 2:
                raise TypeError("Invalid input format")
            data, rows = args
            if len(data) != len(rows):
                raise ValueError("data and rows must have the same length")

            # The base constructor rejects a `shape` keyword when its first
            # argument already is a shape tuple, so take it out of kwargs.
            shape = kwargs.pop('shape', None)
            if shape is None:
                n_rows = len(rows)
                # global maximum over all non-empty rows
                n_cols = max((max(cols) for cols in rows if len(cols)), default=-1) + 1
            else:
                n_rows, n_cols = shape
                if len(rows) > n_rows:
                    raise ValueError("more rows given than shape allows")

            super().__init__((n_rows, n_cols), **kwargs)

            # The base class has created object arrays holding one empty list
            # per row; fill them in place so the internal layout stays valid.
            for i, (cols, vals) in enumerate(zip(rows, data)):
                if len(cols) != len(vals):
                    raise ValueError(f"row {i}: data and rows differ in length")
                if len(cols) and (min(cols) < 0 or max(cols) >= n_cols):
                    raise ValueError(f"row {i}: column index out of range")
                # lil_matrix keeps the column indices of a row sorted
                order = sorted(range(len(cols)), key=lambda j, c=cols: c[j])
                self.rows[i] = [int(cols[j]) for j in order]
                self.data[i] = [self.dtype.type(vals[j]) for j in order]
        else:
            super().__init__(*args, **kwargs)

class graphTester:
    """
    Simple statistics about the receptive field of (sampled) nodes.

    For each sampled node the neighbours reachable within 1..``layer`` hops
    are collected, and the column indices of a CSR matrix row can then be
    broken down by the hop at which they appear.
    """
    def __init__(self, g, layer=3, sample=100, sample_seed=0, device="cpu"):
        self.g = g
        self.layer = layer
        self.sample = sample
        self.sample_seed = sample_seed
        self.device = device

        # node ids to analyse, plus a hash identifying that selection
        self.sample_idx, self.sample_hash = self.get_sample_node(self.g, self.sample, self.sample_seed)

    def gen_breakdown_base(self):
        """Compute and store the per-hop neighbour sets (inclusive and exclusive)."""
        self.inclusive_breakdown_neighbour = self.gen_inclusive_neighbour(self.g, self.sample_idx, self.layer)
        self.exclusive_breakdown_neighbour = self.inclusive_2_exclusive(self.inclusive_breakdown_neighbour)

    def get_sample_node(self, g, sample, sample_seed):
        """Pick the node ids to evaluate.

        With ``sample > 0`` the ids are shuffled with numpy's global RNG
        seeded by ``sample_seed`` and the first ``sample`` are kept;
        otherwise all node ids are returned. Returns ``(ids, hash_string)``.
        """
        to_eval_idx = np.arange(g.num_nodes())
        if sample > 0:
            np.random.seed(sample_seed)
            np.random.shuffle(to_eval_idx)
            to_eval_idx = to_eval_idx[:sample]
        hash_str = hash_list(to_eval_idx.tolist())
        print(f"Hash string for sampled index is {hash_str}.")
        return to_eval_idx, hash_str

    def gen_inclusive_neighbour(self, g, sample_idx, layer):
        """For every sampled node return a list with one node set per hop.

        Hop ``k`` is produced by feeding the nodes of hop ``k-1`` as seeds
        into a one-layer full-neighbour sampler and taking the first element
        of what the sampler returns.
        """
        sampler = MultiLayerFullNeighborSampler(1)
        per_node = []
        for idx in sample_idx:
            frontier = [idx]
            hops = []
            for _ in range(layer):
                frontier = sampler.sample(g, frontier)[0].numpy()
                hops.append(set(frontier))
            per_node.append(hops)
        return per_node

    def inclusive_2_exclusive(self, neighbour_list):
        """Turn per-hop sets into "new at this hop" sets.

        The first hop is copied; every later hop has the set of the hop
        directly before it removed.
        """
        return [
            [deepcopy(hops[0])] + [later - earlier for earlier, later in zip(hops, hops[1:])]
            for hops in neighbour_list
        ]

    def get_mean(self, neighbour_list):
        """Print and return the mean set size per position over all nodes."""
        n_positions = len(neighbour_list[0])
        means = [
            np.mean([len(node[pos]) for node in neighbour_list])
            for pos in range(n_positions)
        ]
        print(",".join(f"{value:.1f}" for value in means))
        return means

    def csr_idx_receptive_overlap(self, idx, csr_A, cur_neighbours, exclusive=True):
        """Break down the column indices of row ``idx`` of ``csr_A`` by hop.

        Returns ``self.layer + 1`` sets: the overlap with each hop's
        neighbour set, followed by the indices left over. In exclusive mode
        matched indices are removed after every hop; otherwise only the
        overlap with the last hop is removed before computing the leftover.
        """
        row_start = csr_A.indptr[idx]
        row_stop = csr_A.indptr[idx + 1]
        remaining = set(csr_A.indices[row_start:row_stop])
        breakdown = []
        for hop in range(self.layer):
            matched = remaining.intersection(cur_neighbours[hop])
            breakdown.append(matched)
            if exclusive:
                remaining = remaining - matched
        if not exclusive:
            remaining = remaining - matched
        breakdown.append(remaining)
        return breakdown

    def csr_receptive_overlap(self, csr_A, exclusive=True):
        """Apply ``csr_idx_receptive_overlap`` to every sampled node."""
        return [
            self.csr_idx_receptive_overlap(
                idx,
                csr_A,
                (self.exclusive_breakdown_neighbour if exclusive
                 else self.inclusive_breakdown_neighbour)[pos],
                exclusive=exclusive,
            )
            for pos, idx in enumerate(self.sample_idx)
        ]

    def disp_neigh_stats(self, exclusive=True):
        """Print and return the mean neighbourhood size per hop."""
        neighbours = (self.exclusive_breakdown_neighbour if exclusive
                      else self.inclusive_breakdown_neighbour)
        return self.get_mean(neighbours)

    def disp_overlap(self, csr_A, exclusive=True):
        """Return the overlap breakdown for ``csr_A`` together with its means."""
        csr_A_overlap = self.csr_receptive_overlap(csr_A, exclusive=exclusive)
        return csr_A_overlap, self.get_mean(csr_A_overlap)

    def csr_receptive_label_state(self, csr_A):
        """Per sampled row, the share of stored entries whose column node has
        the same label as the row node.

        Returns ``(label_fit_ratio, weighted_label_fit_ratio)``: the first
        counts entries, the second weighs them by the values of ``csr_A``.
        """
        def mark_matching_labels(label_csr, row_labels):
            # replace every stored label by 1/0: does it equal the row's label?
            assert label_csr.shape[0] == len(row_labels)
            marked = deepcopy(label_csr)
            stored = marked.data
            bounds = marked.indptr
            for row in range(label_csr.shape[0]):
                lo, hi = bounds[row], bounds[row + 1]
                stored[lo:hi] = stored[lo:hi] == row_labels[row]
            return marked

        csr_A_sub = csr_A[self.sample_idx]
        csr_A_label = deepcopy(csr_A_sub)
        all_labels = self.g.ndata['label'].detach().cpu().numpy()
        # same sparsity pattern, values replaced by the label of the column node
        csr_A_label.data = all_labels[csr_A_label.indices]
        target_label = all_labels[self.sample_idx]
        row_size = csr_A_sub.indptr[1:] - csr_A_sub.indptr[:-1]
        row_weight = np.sum(csr_A_sub, axis=1).flatten()
        label_fit_matrix = mark_matching_labels(csr_A_label, target_label)
        label_fit_size = np.sum(label_fit_matrix, axis=1).flatten()
        weighted_label_fit_size = np.sum(csr_A_sub.multiply(label_fit_matrix), axis=1).flatten()

        return label_fit_size / row_size, weighted_label_fit_size / row_weight


class graphSampler:
    """
    Sample weighted sparse multi-hop matrices from a dgl graph.

    On construction the graph gets a normalized edge weight
    (``norm_A_<normalize>``), a sampling probability (``p_<normalize>``) and
    an importance rate (``important_norm_<normalize>``) stored in its edge
    data, unless they are already present.
    In dgl, the row of the matrix represent the out edges and column represent the in edges.
    """
    def __init__(self,
                 g,
                 batch_size = 256,
                 normalize='both',
                 device='cpu'):
        self.g = g
        self.batch_size = batch_size
        self.normalize = normalize
        self.device = device

        self.adj_name = "norm_A_" + self.normalize
        self.prob_name = "p_" + self.normalize
        self.importance_rate_name = "important_norm_" + self.normalize

        if (not hasattr(self.g, 'normalize')) or self.normalize not in self.g.normalize:
            tic = time()
            print("Generating normalized adj...")
            self.normalize_adj()
            print(f"normalized adj finished in {time()-tic:.1f} s")
        else:
            print("Dataset contains normalized weight, skipped normalization.")

        if self.prob_name in self.g.edata:
            print("Dataset contains probability, skipped probability generation.")
        else:
            tic = time()
            print("Creating importance probability...")
            self.create_probability()
            print(f"Probability matrix generated in {time()-tic:.1f} s")

    def normalize_adj(self):
        """Store the normalized edge weight in ``g.edata[self.adj_name]``.

        Starting from a weight of one per edge, ``"left"`` scales by the
        inverse out-degree of the source, ``"right"`` by the inverse
        in-degree of the destination and ``"both"`` by the inverse square
        root of each. Degrees are clamped to at least 1. The mode is then
        recorded in the list ``g.normalize``.
        """
        g = self.g
        g.edata[self.adj_name] = torch.ones(g.num_edges(), dtype=torch.float32, device=g.device)
        if self.normalize in ["left", "both"]:
            degs = g.out_degrees().clamp(min=1)
            if self.normalize == "both":
                g.srcdata['l_norm'] = torch.pow(degs, -0.5)
            else:
                g.srcdata['l_norm'] = 1.0 / degs
            g.apply_edges(fn.e_mul_u(self.adj_name, 'l_norm', self.adj_name))
        if self.normalize in ["right", "both"]:
            degs = g.in_degrees().clamp(min=1)
            if self.normalize == "both":
                g.dstdata['r_norm'] = torch.pow(degs, -0.5)
            else:
                g.dstdata['r_norm'] = 1.0 / degs
            g.apply_edges(fn.e_mul_v(self.adj_name, 'r_norm', self.adj_name))

        # remember which normalizations this graph already carries
        if not hasattr(self.g, 'normalize'):
            self.g.normalize = [self.normalize]
        else:
            self.g.normalize.append(self.normalize)

    def extract_normed_scipy_coo_adj(self, g, edge_weight):
        """Return ``g``'s adjacency, weighted by ``g.edata[edge_weight]``, as
        a square scipy COO matrix of size ``num_nodes``."""
        indices = g.adj_sparse(fmt="coo")
        indices = [x.detach().cpu().numpy() for x in indices]
        values = g.edata[edge_weight].detach().cpu().numpy()
        adj = sps.coo_matrix((values, indices), shape=(g.num_nodes(), g.num_nodes()))
        return adj

    def extract_all_edges_dense(self, g):
        """Return ``{edge feature name: dense weighted adjacency}`` for every
        entry of ``g.edata``."""
        res = {}
        for k, _ in g.edata.items():
            res[k] = self.extract_normed_scipy_coo_adj(g, k).toarray()
        return res

    def create_probability(self):
        r"""
        Create the sampling probability and the importance rate per edge.

        Due to dgl convention, each row of adj is the vector of out edges. The aggregation
        is on the columns. We need to adjust the orientation.

        p_{iz} \prop A_{iz} l2_norm(A_{*i})

        Writes ``g.edata[self.prob_name]`` and
        ``g.edata[self.importance_rate_name]``; the node fields ``l2_norm``
        and ``p_l1_norm`` and the edge field ``s`` are used as scratch space.
        """
        # get l2 norm of the normed adjacent matrix
        self.g.edata['s'] = torch.pow(self.g.edata[self.adj_name], 2)
        self.g.update_all(fn.copy_e('s', 'm'), fn.sum('m', 'l2_norm'))
        self.g.ndata['l2_norm'] = torch.pow(self.g.ndata['l2_norm'], 0.5)
        # compute the probabilities: weight times the source's l2 norm,
        # rescaled by the inverse of the sum aggregated at the destination
        self.g.edata[self.prob_name] = self.g.edata[self.adj_name].detach().clone()
        self.g.apply_edges(fn.e_mul_u(self.prob_name, 'l2_norm', self.prob_name))
        self.g.update_all(fn.copy_e(self.prob_name, 'm'), fn.sum('m', 'p_l1_norm'))
        self.g.ndata['p_l1_norm'] = torch.nan_to_num(1.0 / self.g.ndata['p_l1_norm'], posinf=1, neginf=-1)
        self.g.apply_edges(fn.e_mul_v(self.prob_name, 'p_l1_norm', self.prob_name))

        # create the importance norm
        self.g.edata[self.importance_rate_name] = self.g.edata[self.adj_name] *\
            torch.nan_to_num(1.0 / self.g.edata[self.prob_name], posinf=1, neginf=-1)

    def plain_conv(self, block, feat, fanout, device="cpu"):
        """Multiply ``feat`` with the block's normalized adjacency.

        ``fanout`` is unused; it is kept so the signature matches
        ``weight_conv``.
        """
        tch_adj = dgl_extract_sparse_tensor_graph(block, self.adj_name)\
                  .coalesce().to(device)
        return torch.sparse.mm(feat, tch_adj)

    def weight_conv(self, block, feat, fanout, device="cpu"):
        """Multiply ``feat`` with the block's importance-rate adjacency and
        divide by ``fanout``."""
        tch_adj = dgl_extract_sparse_tensor_graph(block, self.importance_rate_name)\
                  .coalesce().to(device)
        return torch.sparse.mm(feat, tch_adj) / fanout

    def _sampled_walk(self, fanout, conv, prob=None):
        """Shared driver of the two random-walk generators.

        Iterates over all nodes in shuffled batches, samples ``fanout[j]``
        neighbours (with replacement) per layer, pushes a sparse partial
        identity of the batch's input nodes through ``conv`` once per block
        and stacks the per-batch results into one CSR matrix whose rows are
        ordered by output node id.
        """
        g = self.g
        neighbour_sampler = dgl.dataloading.NeighborSampler(fanout, prob=prob, replace=True)
        loop_loader = dgl.dataloading.DataLoader(g, g.nodes(), neighbour_sampler, device=g.device,
                                                 batch_size=self.batch_size, shuffle=True,
                                                 drop_last=False, num_workers=0,
                                                 use_uva=False)
        t_start = time()
        res_node_idx, res_node_embed = [], []
        total_batch = int(np.ceil(g.num_nodes() / self.batch_size))
        print("", end="")
        for i, (input_nodes, output_nodes, blocks) in enumerate(loop_loader):
            # conv
            feat = create_partial_identity(input_nodes, g.num_nodes())
            feat = scipy_coo_to_torch_sparse(feat).float().to(self.device)
            for j, each_block in enumerate(blocks):
                feat = conv(each_block, feat, fanout[j], self.device)
            res_node_idx.append(output_nodes)
            res_node_embed.append(feat.cpu())
            print(f"\r    Random walk time cost {time()-t_start:.1f},"
                  f" ETA:{(time()-t_start)/(i+1)*(total_batch-i-1):.1f}",
                  end="",
                  flush=True)
        print()
        res_node_idx = torch.concat(res_node_idx).cpu().detach().numpy()
        res = torch_sparse_to_scipy_coo(torch.hstack(res_node_embed).t().cpu()).tocsr()
        # batches arrive shuffled: bring the rows back into node-id order
        res_node_map = np.argsort(res_node_idx)
        return res[res_node_map]

    def gen_plain_random_walk(self, fanout=None):
        """Sampled multi-hop matrix using uniform neighbour sampling and the
        normalized adjacency. ``fanout`` defaults to ``[10, 10, 10]``."""
        if fanout is None:
            fanout = [10, 10, 10]
        return self._sampled_walk(fanout, self.plain_conv)

    def gen_approx_random_walk(self, fanout=None, p=None):
        """
        Sampled multi-hop matrix using neighbour sampling with edge
        probability ``p`` (name of an edge feature, or None) and the
        importance-rate weights. ``fanout`` defaults to ``[10, 10, 10]``.

        The convention is different from dgl, the output matrix is the row-wise in edges
        """
        if fanout is None:
            fanout = [10, 10, 10]
        return self._sampled_walk(fanout, self.weight_conv, prob=p)

    def validate_h_init_theta_sampling(self, sample_cnt=10, fanout=None, weighted=True):
        """Compare the sampled multi-hop matrix against the exact one.

        The exact result is the normalized adjacency raised to the power
        ``len(fanout)`` (transposed); the estimate is the mean over
        ``sample_cnt`` sampled walks. Returns ``(dense normalized adjacency,
        exact result, list of samples, mean sample, relative error)``.
        Everything is densified, so this is only meant for small graphs.
        """
        if fanout is None:
            fanout = [10, 10, 10]
        # compute the true h by matrix multiplication
        normed_A = scipy_coo_to_torch_sparse(self.extract_normed_scipy_coo_adj(self.g, self.adj_name)).to(self.device)
        layer = len(fanout)
        true_res = normed_A.detach()
        for _ in range(layer-1):
            true_res = torch.sparse.mm(true_res, normed_A)
        true_res = torch_sparse_to_scipy_coo(true_res.cpu()).toarray().transpose()
        # compute the sample for a lot of times
        sampled_list = []
        for _ in range(sample_cnt):
            if weighted:
                sampled_list.append(self.gen_approx_random_walk(fanout=fanout, p=self.prob_name).toarray())
            else:
                sampled_list.append(self.gen_plain_random_walk(fanout=fanout).toarray())

        # compute the error
        mean_sample = np.mean(sampled_list, axis=0)
        error = np.linalg.norm(mean_sample - true_res) / np.linalg.norm(true_res)
        return torch_sparse_to_scipy_coo(normed_A.cpu()).toarray(), true_res, sampled_list, mean_sample, error
    
    # def weight_conv_dgl(self, block, feat):
    #     """
    #     Conv layer for the weight computation with dgl message passing
    #     """
    #     block.srcdata['h'] = feat
    #     msg_fn = fn.e_mul_u(self.importance_rate_name, 'h', 'm')
    #     block.update_all(msg_fn, fn.mean('m', 'neigh'))
    #     return block.dstdata['neigh']

    # def gen_approx_random_walk_dgl(self, fanout=[10, 10, 10], p=None):
    #     """
    #     The convention is different from dgl, the output matrix is the row-wise in edges
    #     """
    #     g = self.g
    #     neighbour_sampler = dgl.dataloading.NeighborSampler(fanout, prob=p, replace=True)
    #     loop_loader = dgl.dataloading.DataLoader(g, g.nodes(), neighbour_sampler, device=g.device,
    #                                              batch_size=self.batch_size, shuffle=True,
    #                                              drop_last=False, num_workers=0,
    #                                              use_uva=False)
    #     res_row_dict = {}
    #     t_start = time()
    #     total_batch = int(np.ceil(g.num_nodes() / self.batch_size))
    #     print("", end="")
    #     for i, (input_nodes, output_nodes, blocks) in enumerate(loop_loader):
    #         # conv
    #         feat = create_partial_identity(input_nodes, g.num_nodes())
    #         feat = torch.from_numpy(feat).float().to(g.device)
    #         for each_block in blocks:
    #             layer_res = self.weight_conv_dgl(each_block, feat)
    #             feat = layer_res
    #         res_row_dict.update({k.detach().cpu().numpy().item() : sps.csr_matrix(v.detach().cpu().numpy())
    #                              for k, v in zip(output_nodes, layer_res)})
    #         print(f"\r    Random walk time cost {time()-t_start:.1f},"
    #               f" ETA:{(time()-t_start)/(i+1)*(total_batch-i-1):.1f}",
    #               end="",
    #               flush=True)
    #         print()
    #     # transform the row dictionary into sparse matrix
    #     list_row = [v for _, v in sorted(res_row_dict.items(), key=lambda item: item[0])]
    #     return sps.vstack(list_row)


class SDMPDataPre: # pylint: disable=too-many-instance-attributes
    """
    Prepare the data for SDMP tests. theta is thetaT in paper, same for the config files.
    """
    def __init__(self, data_name, feature_normalize, # pylint: disable=too-many-arguments
                 h_model_mode, h_conf_path, h_model_path, target_h_model,
                 h_init_theta_mode, h_init_theta_k, h_init_theta_k_fanout,
                 theta_cand_mode, theta_cand_k2, theta_cand_k1, theta_cand_fanout,
                 theta_cand_add_self,
                 train_conf,
                 use_cache=True, cache_path=None, save_cache=True, device="cpu"):
        """
        Record the configuration, resolve the cache file names and run ``process``.

        Args:
            data_name: dataset name forwarded to ``load_data``.
            feature_normalize: "standard", "row_sum" or "no" (see ``get_X``).
            h_model_mode: "internal" (embed with the pickled SAGE state) or
                "external" (read the target from ``h_model_path`` with ``np.load``).
            h_conf_path: configuration file passed to ``load_train_conf``.
            h_model_path: model / embedding file; its directory also hosts the
                cached target.
            target_h_model: name of the sub-directory created below ``cache_path``.
            h_init_theta_mode: "sparse", "dense" or "mixed".
            h_init_theta_k: walk length for h_init_theta.
            h_init_theta_k_fanout: one fanout, or a list with one entry per step.
            theta_cand_mode: "sparse", "dense" or "mixed".
            theta_cand_k2: number of sampled steps of theta_cand.
            theta_cand_k1: adjacency power used for the dense part of theta_cand.
            theta_cand_fanout: one fanout, or a list with ``theta_cand_k2`` entries.
            theta_cand_add_self: flag recorded in the cache file names.
            train_conf: mapping providing "target_normalize" and
                "target_normalize_param".
            use_cache: load artefacts that already exist on disk.
            cache_path: root directory of the cache.
            save_cache: store freshly computed artefacts.
            device: torch device used for the heavy computations.
        """
        self.data_name = data_name
        self.feature_normalize = feature_normalize
        self.h_model_mode = h_model_mode
        self.h_conf_path, self.h_model_path, self.target_h_model =\
            h_conf_path, h_model_path, target_h_model
        self.h_init_theta_mode = h_init_theta_mode
        self.h_init_theta_k = h_init_theta_k
        self.h_init_theta_k_fanout = h_init_theta_k_fanout
        self.theta_cand_mode = theta_cand_mode
        self.theta_cand_k2 = theta_cand_k2
        self.theta_cand_k1 = theta_cand_k1
        self.theta_cand_fanout = theta_cand_fanout
        self.theta_cand_add_self = theta_cand_add_self
        self.train_conf = train_conf
        self.use_cache = use_cache
        self.cache_path = cache_path
        self.save_cache = save_cache
        self.device = device
        # Expands scalar fanouts into per-step lists before they are used in file names.
        self.validate_params()

        self.h_model_root = os.path.dirname(self.h_model_path)

        # The file names are needed for loading AND for saving. Resolving them only
        # when use_cache was set left them undefined for use_cache=False with
        # save_cache=True, which failed with AttributeError at the first save.
        if self.use_cache or self.save_cache:
            self.A_pow_path, self.theta_cand_path, self.h_init_theta_path,\
                self.X_path, self.target_path = self.get_paths()

        self.g = None
        self.theta_cand, self.h_init_theta, self.X, self.target = self.process()

    
    def validate_params(self):
        """
        Normalise the fanout settings and check their consistency.

        A fanout that is not a list is repeated once per step; a list must
        already have one entry per step. For the "mixed" candidate mode the
        sampled depth ``theta_cand_k2`` must not be smaller than
        ``theta_cand_k1``. Violations raise ``AssertionError``.
        """
        if self.h_init_theta_mode == "sparse":
            if not isinstance(self.h_init_theta_k_fanout, list):
                self.h_init_theta_k_fanout = [self.h_init_theta_k_fanout] * self.h_init_theta_k
            else:
                assert len(self.h_init_theta_k_fanout) == self.h_init_theta_k

        if self.theta_cand_mode not in ("sparse", "mixed"):
            return

        if not isinstance(self.theta_cand_fanout, list):
            self.theta_cand_fanout = [self.theta_cand_fanout] * self.theta_cand_k2
        else:
            assert len(self.theta_cand_fanout) == self.theta_cand_k2
        if self.theta_cand_mode == "mixed":
            assert self.theta_cand_k2 >= self.theta_cand_k1
    
    def get_A_pow(self, k, use_cache=True):
        """
        Return the k-th power of the transposed, normalized adjacency matrix.

        Args:
            k: exponent forwarded to ``gen_adj_power_layerwise``.
            use_cache: when true, the matrix is read from / written to
                ``<cache_path>/A_pow_<k>``; when false the cache directory is
                not touched at all.

        Returns:
            The matrix produced by ``h5_load_csr`` or ``gen_adj_power_layerwise``.
        """
        # Build the file name only when caching is requested: joining with
        # cache_path=None raises TypeError, which used to break use_cache=False.
        A_path = None
        if use_cache:
            A_path = os.path.join(self.cache_path, f"A_pow_{k}")
            if os.path.isfile(A_path):
                return h5_load_csr("A_pow", A_path)
        # NOTE(review): self.g is only populated by process() when something had to
        # be rebuilt - confirm the graph is loaded before calling this on a cache miss.
        sampler = graphSampler(self.g, batch_size = 256, normalize='both', device=self.device)
        A = sampler.extract_normed_scipy_coo_adj(self.g, sampler.adj_name)
        A = A.transpose().tocsr()
        A = self.gen_adj_power_layerwise(k, A)
        if use_cache:
            h5_store_csr(A, "A_pow", A_path)
        return A
    
    def get_A_pow_local_indices(self, k, indices):
        """
        Compute only the rows ``indices`` of the k-th adjacency power.

        The selected rows of the normalized adjacency are multiplied ``k - 1``
        times by the full adjacency on ``self.device``; the product is returned
        as a scipy CSR matrix on the CPU.
        """
        # NOTE(review): unlike get_A_pow, the adjacency is not transposed here -
        # confirm this is intended for non-symmetric graphs.
        sampler = graphSampler(self.g, batch_size = 256, normalize='both', device=self.device)
        adj = sampler.extract_normed_scipy_coo_adj(self.g, sampler.adj_name)
        picked_rows = adj.tocsr()[indices, :].tocoo()
        rows_on_device = scipy_coo_to_torch_sparse(picked_rows).to(self.device)
        adj_on_device = scipy_coo_to_torch_sparse(adj).to(self.device)

        for _ in range(1, k):
            rows_on_device = torch.sparse.mm(rows_on_device, adj_on_device)

        rows_on_cpu = rows_on_device.detach().cpu()
        return torch_sparse_to_scipy_coo(rows_on_cpu).tocsr()
        
    def get_paths(self):
        """
        Define the cache paths
        """
        root_path = os.path.join(self.cache_path, self.target_h_model)

        if not os.path.isdir(root_path):
            os.makedirs(root_path)
        self.root_path = root_path
        
        A_pow_path = os.path.join(self.cache_path,
                                  f"A_pow_{self.theta_cand_k1}")

        theta_cand_path = os.path.join(root_path,
                                       f"theta_{self.theta_cand_mode}")
        if self.theta_cand_mode in ["mixed", "dense"]:
            theta_cand_path += f"_k1_{self.theta_cand_k1}"
        if self.theta_cand_mode in ["sparse", "mixed"]:
            theta_cand_path += f"_k2_{self.theta_cand_k2}"+\
                               f"_{'_'.join([str(i) for i in self.theta_cand_fanout])}"
        if self.theta_cand_add_self:
            theta_cand_path += "_withself"
        else:
            theta_cand_path += "_noself"

        
        h_init_theta_path = os.path.join(root_path,
                                         f"h_init_theta_{self.h_init_theta_mode}"
                                         f"_{self.h_init_theta_k}")
        if self.h_init_theta_mode == "sparse":
            h_init_theta_path += f"_{'_'.join([str(i) for i in self.h_init_theta_k_fanout])}"
        if self.theta_cand_add_self:
            h_init_theta_path += "_withself"
        else:
            h_init_theta_path += "_noself"
        
        X_path = os.path.join(root_path, f"features_normed_{self.feature_normalize}")
        target_path = "target_cache_SDMP_"+ os.path.basename(self.h_model_path).split(".")[0] +\
            "_"+ self.train_conf["target_normalize"] + "_" +\
            str(self.train_conf["target_normalize_param"]) +"_.h5"
        target_path = os.path.join(self.h_model_root, target_path)
        
        return A_pow_path, theta_cand_path, h_init_theta_path, X_path, target_path

    def process(self):
        """
        Load each artefact from the cache when available, otherwise compute it.

        The graph is only loaded when at least one artefact is missing (or the
        cache is disabled). Freshly computed artefacts are written back when
        ``save_cache`` is set.

        Returns:
            tuple ``(theta_cand, h_init_theta, X, target)``.

        Raises:
            ValueError: for an unknown ``h_model_mode`` or target normalization.
        """
        started = time()

        # Cache look-up; nothing counts as cached when the cache is disabled.
        caching = self.use_cache
        theta_hit = caching and os.path.isfile(self.theta_cand_path)
        h_hit = caching and os.path.isfile(self.h_init_theta_path)
        feat_hit = caching and os.path.isfile(self.X_path)
        target_hit = caching and os.path.isfile(self.target_path)

        if not (theta_hit and h_hit and feat_hit and target_hit):
            self.g = load_data(self.data_name)

        # Candidate theta.
        # NOTE(review): theta_cand_add_self is not applied to the matrix here; within
        # this class it only influences the cache file names.
        if theta_hit:
            theta_cand = h5_load_csr("theta", self.theta_cand_path)
        else:
            theta_cand = self.gen_plain_random_walk_matrix(self.theta_cand_mode,
                                                           k1=self.theta_cand_k1,
                                                           k2_fanout=self.theta_cand_fanout)
            if self.save_cache:
                h5_store_csr(theta_cand, "theta", self.theta_cand_path)

        # Initial theta of h.
        if h_hit:
            h_init_theta = h5_load_csr("h_init_theta", self.h_init_theta_path)
        else:
            h_init_theta = self.process_gen_random_walk_matrix(
                self.h_init_theta_mode,
                k2_fanout=self.h_init_theta_k_fanout,
                k1=self.h_init_theta_k)
            if self.save_cache:
                h5_store_csr(h_init_theta, "h_init_theta", self.h_init_theta_path)

        # Node features.
        if feat_hit:
            X = h5_load_np_array("X", self.X_path)
        else:
            X = self.get_X(self.g, self.feature_normalize)
            if self.save_cache:
                h5_store_np_array(X, "X", self.X_path)

        # Regression target.
        if target_hit:
            target = h5_load_np_array("target", self.target_path)
        else:
            # Raw target, either embedded here or read from a file.
            if self.h_model_mode == "internal":
                target = self.get_SAGE_embed()
            elif self.h_model_mode == "external":
                target = np.load(self.h_model_path)['arr_0']
            else:
                raise ValueError(f"Unrecognized target h model mode {self.h_model_mode}.")

            # Optional squashing of the target.
            norm_kind = self.train_conf["target_normalize"]
            if norm_kind == "sigmoid":
                target = sigmoid_np_with_torch(target,
                                               self.train_conf["target_normalize_param"],
                                               device=self.device)
            elif norm_kind != "no":
                raise ValueError(f"Unrecognized target normalizaiton {self.train_conf['target_normalize']}")

            if self.save_cache:
                h5_store_np_array(target, "target", self.target_path)

        print(f"Process SAGE finished in {time() - started:.1f} s.")
        return theta_cand, h_init_theta, X, target
    
    def gen_plain_random_walk_matrix(self, mode, k1=None, k2_fanout=None, batch_size=256):
        """
        Build the candidate matrix from a plain random walk (receptive field only).

        Args:
            mode: "dense" (adjacency power), "sparse" (sampled walk) or "mixed"
                (sum of both).
            k1: adjacency power used by the dense part.
            k2_fanout: fanout list handed to ``gen_plain_random_walk``.
            batch_size: batch size handed to ``graphSampler``.

        Returns:
            Sum of the dense and the sparse part; a part that is not requested
            contributes an empty CSR matrix.

        Raises:
            TypeError: if the graph does not carry the
                ``tag_remove_and_addedd_loop`` marker.
        """
        if "tag_remove_and_addedd_loop" not in self.g.__dict__:
            raise TypeError("Graph did not remove and added selfloop. Dense algorithm will not work for this case!!")
        sampler = graphSampler(self.g, batch_size = batch_size, normalize='both', device=self.device)
        g = self.g
        shape = (g.num_nodes(), g.num_nodes())
        if mode in ["dense", "mixed"]:
            # Key the cached power by the k1 that is actually computed. The former
            # self.A_pow_path was always named after theta_cand_k1 and did not exist
            # at all when use_cache was off.
            a_pow_path = None
            if self.cache_path is not None:
                a_pow_path = os.path.join(self.cache_path, f"A_pow_{k1}")
            if self.use_cache and a_pow_path is not None and os.path.isfile(a_pow_path):
                dense_neigh = h5_load_csr("A_pow", a_pow_path)
            else:
                A = sampler.extract_normed_scipy_coo_adj(g, sampler.adj_name)
                A = A.transpose().tocsr()
                dense_neigh = self.gen_adj_power_layerwise(k1, A)
                # Without a cache directory there is nowhere to store the result.
                if self.save_cache and a_pow_path is not None:
                    h5_store_csr(dense_neigh, "A_pow", a_pow_path)
        else:
            dense_neigh = sps.csr_matrix(shape)
        if mode in ["sparse", "mixed"]:
            sparse_neigh = sampler.gen_plain_random_walk(fanout=k2_fanout)
        else:
            sparse_neigh = sps.csr_matrix(shape)

        return dense_neigh + sparse_neigh
            
    def process_gen_random_walk_matrix(self,
                                       mode,
                                       A_pow=None,
                                       k2_fanout=None,
                                       k1=None,
                                       batch_size=256,
                                       normalize='both'):
        """
        Build a walk matrix with the sampler's importance-weighted random walk.

        Args:
            mode: "sparse" (sampled walk), "dense" (adjacency power) or "mixed"
                (sum of both).
            A_pow: precomputed adjacency power; when given it is used as the
                dense part instead of being recomputed.
            k2_fanout: fanout list handed to ``gen_approx_random_walk``.
            k1: adjacency power used by the dense part.
            batch_size: batch size handed to ``graphSampler``.
            normalize: normalization handed to ``graphSampler``.

        Returns:
            Sum of the sparse and the dense part; a part that is not requested
            contributes an empty CSR matrix.
        """
        sampler = graphSampler(self.g, batch_size = batch_size, normalize=normalize, device=self.device)
        g = self.g
        shape = (g.num_nodes(), g.num_nodes())
        if mode in ["sparse", "mixed"]:
            sparse_neigh = sampler.gen_approx_random_walk(fanout=k2_fanout,
                                                          p=sampler.prob_name)
        else:
            sparse_neigh = sps.csr_matrix(shape)
        if mode in ["dense", "mixed"]:
            if A_pow is not None:
                dense_neigh = A_pow
            else:
                A = sampler.extract_normed_scipy_coo_adj(g, sampler.adj_name)
                A = A.transpose().tocsr()
                dense_neigh = self.gen_adj_power_layerwise(k1, A)
                # Store under the power that was computed. self.A_pow_path is named
                # after theta_cand_k1, so writing A^k1 there put the wrong matrix in
                # the shared cache whenever k1 differed. Saving now follows save_cache,
                # like the other artefacts.
                if self.save_cache and self.cache_path is not None:
                    h5_store_csr(dense_neigh, "A_pow",
                                 os.path.join(self.cache_path, f"A_pow_{k1}"))
        else:
            dense_neigh = sps.csr_matrix(shape)
        return sparse_neigh + dense_neigh
    
    def get_X(self, g, feature_normalize):
        print("Preprocessing X.")
        features_compute_device = g.ndata['feat'].to(self.device)
        if feature_normalize == "standard":
            features_compute_device =\
                (features_compute_device-features_compute_device.mean(dim=0))\
                /features_compute_device.std(dim=0)
            X = features_compute_device.detach().cpu().numpy()
        elif feature_normalize == "row_sum":
            X = row_norm(features_compute_device.detach().cpu().numpy())
        elif feature_normalize == "no":
            X = g.ndata['feat'].detach().cpu().numpy()
        else:
            raise ValueError(f'Unrecognized feature normalization: {feature_normalize}.')
        return X

    def get_SAGE_embed(self, keep_last_aggregation=True):
        """
        Run the stored SAGE model over the whole graph and return its output.

        With ``keep_last_aggregation=True`` the parameters of the last layer are
        overwritten before loading: bias set to zero, ``fc_neigh`` (and
        ``fc_self`` when present) set to the identity, and the output size set
        to the size of the last embedding, so that the last layer only
        aggregates. Otherwise the state is loaded untouched and ``inference``
        is asked for the embedding (``return_embed=True``).

        Returns:
            numpy array produced by ``SAGE.inference``.
        """
        in_size = self.g.ndata['feat'].shape[1]
        out_size = self.g.num_classes
        model_conf = load_train_conf(self.h_conf_path)
        hidden_size = model_conf["hidden_size"]

        # NOTE: pickle executes arbitrary code while loading - only point
        # h_model_path at trusted files.
        with open(self.h_model_path, 'rb') as fin:
            serialized = fin.read()
        self.model_state = pickle.loads(serialized)

        if keep_last_aggregation:
            n_layers = model_conf["hidden_layer"]
            last_embed_size = hidden_size if n_layers else in_size
            prefix = f'layers.{n_layers - 1:d}'
            self_key = f'{prefix}.fc_self.weight'

            self.model_state[f'{prefix}.bias'] = torch.zeros(last_embed_size).to(self.device)
            if self_key in self.model_state:
                self.model_state[self_key] = torch.eye(last_embed_size).to(self.device)
            self.model_state[f'{prefix}.fc_neigh.weight'] =\
                torch.eye(last_embed_size).to(self.device)
            out_size = last_embed_size

        model = SAGE(in_size,
                     hidden_size,
                     out_size,
                     GNN_layer=model_conf["hidden_layer"],
                     dropout=model_conf["dropout"])
        self.GNN_model = model.to(self.device)
        self.GNN_model.load_state_dict(self.model_state)
        self.GNN_model.eval()

        with torch.no_grad():
            target = self.GNN_model.inference(self.g, self.device, 32,
                                              is_sample=False, sample_size=0,
                                              return_embed=not keep_last_aggregation)
            target = target.cpu().detach().numpy()
        return target
    
    def gen_adj_power_layerwise(self, k, norm_g): # pylint: disable=invalid-name
        """
        generate k power of adjacent matrix from dgl graph efficiently on GPU
        input:
            k: number of walks in SGC model
            norm_g: normalized_g on cpu
        output:
            scipy coo_matrix that contains the desired matrix
        """
        # GPU_MEM_SIZE = 28 * 1024 ** 3
        data_size = norm_g.shape[0]
        # batch_size = int(GPU_MEM_SIZE / (data_size * 4 * 10))
        batch_size = 8
        total_batch = int(np.ceil(data_size / batch_size))
        print(f"        Gen SGC batch size/data size/total batch:"
              f" {batch_size}/{data_size}/{total_batch}")

        # norm_g = norm_g.tocsr()
        norm_g_compute_device = scipy_coo_to_torch_sparse(norm_g.tocoo()).to(self.device)
        
        res = norm_g
        for cur_k in range(k-1):
            print(f"    Iteration {cur_k}/{k-1}.")
            all_res = []        
            max_batch_nonzero = 0
            t_start = time()
            print("", end="")
            for i in range(total_batch):
                cur_res = res[i*batch_size : min((i+1)*batch_size, data_size), :]
                with torch.no_grad():
                    cur_res_gpu = scipy_coo_to_torch_sparse(cur_res.tocoo()).to(self.device)
                    cur_res_gpu_new = torch.sparse.mm(cur_res_gpu,
                                                      norm_g_compute_device)
                    del cur_res_gpu
                    all_res.append(torch_sparse_to_scipy_coo(cur_res_gpu_new.cpu().detach()))
                    del cur_res_gpu_new
                torch.cuda.empty_cache()
                # gc.collect()
                cur_nonzero = cur_res.count_nonzero()
                if cur_nonzero > max_batch_nonzero:
                    max_batch_nonzero = cur_nonzero
                print(f"\r        nonzeros/max: {cur_nonzero}/{max_batch_nonzero}"
                      f" adj compute batch: {i+1}/{total_batch}, "
                      f"time cost {time()-t_start:.1f},"
                      f" ETA:{(time()-t_start)/(i+1)*(total_batch-i-1):.1f}",
                      end="",
                      flush=True)
            print()
            res = sps.vstack(all_res).tocsr()
        # end of k loop
        return res

    def gen_adj_power_minibatch(self, k, norm_g): # pylint: disable=invalid-name
        """
        generate k power of adjacent matrix from dgl graph efficiently on GPU
        input:
            k: number of walks in SGC model
            norm_g: normalized_g on cpu
            norm_g_compute_device: normalized g on compute device
        output:
            scipy coo_matrix that contains the desired matrix
        """
        # GPU_MEM_SIZE = 28 * 1024 ** 3
        data_size = norm_g.shape[0]
        # batch_size = int(GPU_MEM_SIZE / (data_size * 4 * 10))
        batch_size = 8
        total_batch = int(np.ceil(data_size / batch_size))
        print(f"        Gen SGC batch size/data size/total batch:"
              f" {batch_size}/{data_size}/{total_batch}")

        all_res = []
        norm_g = norm_g.tocsr()
        norm_g_compute_device = scipy_coo_to_torch_sparse(norm_g.tocoo()).to(self.device)
        max_batch_nonzero = 0
        t_start = time()
        print("", end="")
        for i in range(total_batch):
            cur_res = norm_g[i*batch_size : min((i+1)*batch_size, data_size), :]
            with torch.no_grad():
                cur_res_gpu = scipy_coo_to_torch_sparse(cur_res.tocoo()).to(self.device)
                for _ in range(k-1):
                    cur_res_gpu_new = torch.sparse.mm(cur_res_gpu.detach(), norm_g_compute_device)
                    del cur_res_gpu
                    cur_res_gpu = cur_res_gpu_new
                all_res.append(torch_sparse_to_scipy_coo(cur_res_gpu.cpu().detach()))
                del cur_res_gpu
            torch.cuda.empty_cache()
            # gc.collect()
            cur_nonzero = cur_res.count_nonzero()
            if cur_nonzero > max_batch_nonzero:
                max_batch_nonzero = cur_nonzero
            print(f"\r        nonzeros/max: {cur_nonzero}/{max_batch_nonzero}"
                  f" adj compute batch: {i+1}/{total_batch}, "
                  f"time cost {time()-t_start:.1f},"
                  f" ETA:{(time()-t_start)/(i+1)*(total_batch-i-1):.1f}",
                  end="",
                  flush=True)
        print()
        res = sps.vstack(all_res)
        return res.tocsr()
    
    def disp_states(self):
        def disp_csr_row_cnt(foo, name=""):
            cand_cnt = foo.indptr[1:] - foo.indptr[:-1]
            print(f"{name} nonzeros cnt {np.mean(cand_cnt):.1f} \u00B1 {np.std(cand_cnt):.1f}")
        # def disp_lil_row_cnt(foo, name=""):
        #     cand_cnt = [len(i) for i in foo.rows]
        #     print(f"{name} nonzeros cnt {np.mean(cand_cnt):.1f} \u00B1 {np.std(cand_cnt):.1f}")
        # theta_cand
        disp_csr_row_cnt(self.theta_cand, name="theta_cand")
        disp_csr_row_cnt(self.h_init_theta, name="h_init_theta")

class PlainLoader:
    """
    Mini-batch iterator over a feature array and a label array.

    Each pass visits the entries of ``sample_idx`` in chunks of ``batch_size``
    (the last chunk may be shorter) and yields ``(feature[idx], label[idx])``.
    With ``is_shuffle`` the visiting order is re-drawn with ``torch.randperm``
    at the start of every pass.
    """
    def __init__(self, feature, label, batch_size, sample_idx, is_shuffle=True):
        self.feature, self.label = feature, label
        self.batch_size = batch_size
        self.sample_idx = sample_idx
        self.is_shuffle = is_shuffle

        self.sample_size = self.sample_idx.shape[0]
        self.max_iteration = int(np.ceil(self.sample_size / self.batch_size))

        # Position and visiting order of the running pass; set by __iter__.
        self.__batch_no = None
        self.__order = None

    def __iter__(self):
        self.__batch_no = 0
        self.__order = (self.sample_idx[torch.randperm(self.sample_size)]
                        if self.is_shuffle else self.sample_idx)
        return self

    def __next__(self):
        if not self.__batch_no < self.max_iteration:
            raise StopIteration
        start = self.__batch_no * self.batch_size
        stop = (self.__batch_no + 1) * self.batch_size
        chosen = self.__order[start:stop]
        batch = (self.feature[chosen], self.label[chosen])
        self.__batch_no += 1
        return batch

class PlainListLoader:
    """
    Mini-batch iterator over several arrays at once.

    Every object in ``foo_list`` is indexed along its first dimension with the
    same chunk of ``sample_idx``; each step yields the list of indexed
    objects in the order of ``foo_list``. With ``is_shuffle`` the visiting
    order is re-drawn with ``torch.randperm`` at the start of every pass.
    """
    def __init__(self, foo_list, batch_size, sample_idx, is_shuffle=True):
        self.foo_list = foo_list
        self.batch_size = batch_size
        self.sample_idx = sample_idx
        self.is_shuffle = is_shuffle

        self.sample_size = self.sample_idx.shape[0]
        self.max_iteration = int(np.ceil(self.sample_size / self.batch_size))

        # Position and visiting order of the running pass; set by __iter__.
        self.__batch_no = None
        self.__order = None

    def __iter__(self):
        self.__batch_no = 0
        self.__order = (self.sample_idx[torch.randperm(self.sample_size)]
                        if self.is_shuffle else self.sample_idx)
        return self

    def __next__(self):
        if not self.__batch_no < self.max_iteration:
            raise StopIteration
        start = self.__batch_no * self.batch_size
        stop = (self.__batch_no + 1) * self.batch_size
        chosen = self.__order[start:stop]
        batch = [member[chosen] for member in self.foo_list]
        self.__batch_no += 1
        return batch

class InfiniteLooper:
    """
    Endless stream of label batches.

    ``next()`` never raises ``StopIteration``: once every full batch of the
    current pass has been served, the looper starts a new pass (reshuffling
    when ``is_shuffle`` is set). A trailing batch that would be shorter than
    ``batch_size`` is skipped, unless it is the only batch there is.

    Args:
        label: sliceable sequence of labels.
        batch_size: number of labels per batch.
        is_shuffle: reshuffle the labels at the start of every pass.
    """
    def __init__(self, label, batch_size, is_shuffle=True):
        # random.shuffle works in place; shuffle a private copy so that the
        # caller's sequence keeps its order.
        self.label = copy(label) if is_shuffle else label
        self.batch_size = batch_size
        self.is_shuffle = is_shuffle

        self.sample_size = len(label)
        self.max_iteration = int(np.ceil(self.sample_size / max(self.batch_size, 1)))
        # Batches served per pass: all full ones, but at least one. The former
        # test against max_iteration-1 also dropped a complete final batch
        # whenever sample_size was a multiple of batch_size.
        self.__batches_per_pass = max(self.sample_size // max(self.batch_size, 1), 1)

        self.__iter_cnt, self.__iter_idx = 0, []

    def __iter__(self):
        self.__iter_cnt = 0
        if self.is_shuffle:
            random.shuffle(self.label)
        return self

    def __next__(self):
        if self.__iter_cnt >= self.__batches_per_pass:
            # Pass exhausted: rewind (and reshuffle) instead of stopping.
            self.__iter__()
        cur_label = self.label[self.__iter_cnt*self.batch_size:
                               (self.__iter_cnt+1)*self.batch_size]
        self.__iter_cnt += 1
        return cur_label
